# Introduction ======

#We are going to put the code used to make figure 5 (EEG-behaviour interaction) here

# Set up ----

## Load packages and metadata ---------

#Load packages
pacman::p_load(tidyverse,patchwork,ghibli,eegUtils,gtsummary,
               lme4)


#Get our default settings
source("./eLife Submission Scripts/Analysis-Common-Utilities.R")


## Load our data ----- 

#Read the EEG summary data and drop the angle measure in the same step
d_sum = 
  readr::read_rds("./eLife Submission Data/sleep_study_eeg_summary_data_z.rds") |>
  filter(measure != "itpc_overlap_angle")

#Load the data fitted in Analysis-Interaction-Clusters.R:
#the cluster results and the models used for the topoplots
fig5_clusters = 
  readr::read_rds("./eLife Submission Data/sleep_study_interaction_clusters.rds")
fig5_models = 
  readr::read_rds("./eLife Submission Data/sleep_study_interaction_models.rds")



# Figure 5 ======

## Fig 5A: Illustrative Scatter Plot =====

#Let's make scatter plots to illustrate the nature of the interactions we see

#Use a subset of measures based on our results.
#These names are also used as the panel labels and as the factor levels that
#fix the left-to-right order of the panels.
exposures_eeg = c("cons_REM","spin_amp","so_amp","itpc_overlap_mag_z")

#Very simple scatter: one row per measure, each holding the Cz data, a mixed
#model of morning hits, and the prediction plot drawn from that model
mi = 
  d_sum |>
  #"cons_REM" is matched through measure == "cons" & stage == "REM";
  #the remaining exposures are matched on the measure name alone
  filter(measure %in% exposures_eeg | (measure == "cons" & stage == "REM") ) |>
  unnest(data) |>
  filter(electrode == "Cz") |>
  left_join(meta |> select(subject, hits_morning),by = "subject") |>
  #Build a measure_stage label, then drop the "_N2&N3" suffix so the labels
  #line up with the names in exposures_eeg
  mutate(m_s = interaction(measure,stage,sep = "_")) |>
  mutate(m_s = map_chr(m_s,str_remove,"_N2&N3")) |>
  #Reuse exposures_eeg for the levels rather than repeating the vector
  mutate(m_s = factor(m_s, levels = exposures_eeg)) |>
  ungroup() |>
  #group_by() + nest() returns rows in factor-level order, which sets the
  #panel order; ungroup afterwards so later steps see a plain tibble
  group_by(m_s) |>
  nest() |>
  ungroup() |>
  #Fit hits ~ EEG value x group, adjusting for age and gender, with a random
  #intercept per family
  mutate(model = map(data,~ lmer(hits_morning ~ value*group + age_eeg + gender + (1|family),
                                 data = .x))) |>
  #Plot the model predictions for value by group on top of the raw data
  mutate(plot  = map2(model,m_s, ~ sjPlot::plot_model(.x, type = "pred",terms = c("value","group"),
                                                      show.data = TRUE, dot.size = 0.5,
                                                      colors = cols,alpha = 0.4,line.size = 0.5) +
                        theme_bw() +
                        theme(panel.grid = element_blank(),
                              axis.text  = element_text(colour = "grey20",size = 6),
                              axis.title = element_text(colour = "black",size = 6),
                              axis.text.y = element_text(colour = "grey20",size = 6),
                              axis.text.x = element_text(colour = "grey20",size = 6),
                              plot.title  = element_text(colour = "black",size = 6),
                              plot.subtitle = element_text(colour = "black",size = 6),
                              legend.position = "none") +
                        #.y is a factor element; pass its label explicitly
                        labs(title = NULL,subtitle = as.character(.y), y = "Hits (n)") +
                        scale_y_continuous(breaks = c(0,5,10, 15)) +
                        coord_cartesian(ylim = c(0,17))
                      
  ) )


#Lay the per-measure scatter plots out in a single row
p_interaction_scatter = wrap_plots(mi$plot,ncol = 4)



## Fig 5B: Topoplots ======


#Electrode positions for cluster rows with p.value below 0.01
clus_use = 
  fig5_clusters |>
  dplyr::select(-data) |>
  #Expand the nested cluster table, then its nested electrode lists
  unnest(clus) |>
  unnest(electrodes) |>
  filter(p.value < 0.01) |>
  #Rename to the column names used by the topoplot aesthetics
  rename(amplitude = statistic,
         electrode = electrodes) |>
  #Attach the coordinates held in topo
  left_join(topo, by = "electrode")

#Topoplots for the subset of measures that have group-level differences:
#one facet per measure, filled by the per-electrode model estimate, with the
#electrodes in clus_use marked on top
p_interaction_topo = 
  fig5_models |>
  ungroup() |>
  unnest(e_table) |>
  rename(amplitude = estimate) |>
  left_join(topo, by = "electrode") |>
  dplyr::select(m_s, electrode, x, y, amplitude) |>
  ggplot(aes(x, y,
             z = amplitude,
             fill = amplitude,
             label = electrode)) +
  #Interpolated scalp map
  geom_topo(grid_res = 200,
            colour = "white", size = 0.1,
            interp_limit = "head",
            chan_markers = "point", chan_size = 0.25,
            head_size = 0.5,
            method = "gam",
            breaks = 10) + 
  #Cluster electrodes: a white dot drawn over a slightly larger black dot
  geom_point(data = clus_use, aes(x = x, y = y),
             colour = "black", fill = "black", size = 1.5) +
  geom_point(data = clus_use, aes(x = x, y = y),
             colour = "white", fill = "white", size = 1.2) +
  #Diverging palette; values outside the limits are squished to the ends
  scale_fill_distiller(palette = "RdBu",
                       limits = c(-0.75, 0.75),
                       oob = scales::squish) +
  facet_wrap(~m_s, ncol = 4) +
  theme_void() + 
  theme(
    panel.grid = element_blank(),
    # panel.border = element_blank(),
    axis.text = element_blank(),
    axis.title = element_blank(),
    axis.text.y = element_blank(),
    axis.text.x = element_blank(),
    strip.text.y = element_text(colour = "black", size = 6, angle = 0),
    strip.text.x = element_text(colour = "black", size = 6, angle = 0),
    legend.title = element_text(colour = "black", size = 6),
    legend.text = element_text(colour = "black", size = 6),
    legend.position = "bottom"
  ) +
  coord_equal() + 
  labs(fill = "22q > Sib Difference") 


# Save -----

#Stack the scatter row over the topoplot row (topoplots 1.5x taller) and tag
#the panels with capital letters
figure_5 = 
  (p_interaction_scatter / p_interaction_topo) +
  plot_layout(heights = c(1, 1.5)) +
  plot_annotation(tag_levels = "A")

#Write the assembled figure to PDF at 14 x 8 cm
ggsave(filename = "./Figures/figure_5.pdf",
       plot = figure_5,
       width = 14,
       height = 8,
       units = "cm")
